
Conversation

@ricardoV94 (Member) commented Jan 12, 2026

(join|split)_dims

I've changed join_dims to be a true mirror of split_dims. The JoinDims Op itself already was; if we make the helper also behave like the Op, we can simplify logic elsewhere.

The signature is now join_dims(x, axis: int = 0, n_axes: int | None = None).

The main change is that:

  1. join_dims(x, n_axes=0) implies an expand_dims. This is the mirror of split_dims(x, split_shape=()) implying a squeeze.
  2. join_dims treats scalar (0d) inputs as if they had one dimension when reasoning about axis; otherwise axis=0/-1 wouldn't make sense. This is analogous to expand_dims.
    Also:
  3. split_dims's axis no longer accepts None; the default is 0.

It also has the pleasant side effect that you can't even specify non-consecutive axes, which the other syntax would suggest is possible (before erroring out). The only way to fail is with an axis or n_axes that is too large.
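
A minimal NumPy sketch of these semantics, for illustration only (not the actual implementation; treating n_axes=None as "join all trailing axes" is my assumption):

```python
import numpy as np

def join_dims_sketch(x, axis=0, n_axes=None):
    x = np.asarray(x)
    ndim = max(x.ndim, 1)              # point 2: scalars act as 1d for axis checks
    if axis < 0:
        axis += ndim
    if n_axes is None:
        n_axes = ndim - axis           # assumption: None joins all trailing axes
    if n_axes == 0:
        return np.expand_dims(x, axis)  # point 1: joining zero axes inserts one
    joined = int(np.prod(x.shape[axis:axis + n_axes]))
    return x.reshape(x.shape[:axis] + (joined,) + x.shape[axis + n_axes:])

x = np.zeros((2, 3, 4))
assert join_dims_sketch(x, axis=1, n_axes=2).shape == (2, 12)
assert join_dims_sketch(x, axis=1, n_axes=0).shape == (2, 1, 3, 4)
assert join_dims_sketch(np.float64(0.0), axis=-1, n_axes=0).shape == (1,)
```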

(un)pack

Renamed the (un)pack axes argument to keep_axes.
Also:

  • Allow default None on unpack
  • Allow default of None to work even with >1d inputs (reverted)
  • Fix case with single input
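
For intuition, a rough NumPy illustration of the pack/unpack idea only (not the real implementation, which is symbolic; the keep_axes handling and return values here are assumptions, with the kept axes taken to be the leading ones):

```python
import numpy as np

def pack_sketch(tensors, n_keep_axes=1):
    # Flatten everything after the kept leading axes, then concatenate.
    kept = tensors[0].shape[:n_keep_axes]
    flat = [t.reshape(*kept, -1) for t in tensors]
    shapes = [t.shape[n_keep_axes:] for t in tensors]  # bookkeeping for unpack
    return np.concatenate(flat, axis=-1), shapes

def unpack_sketch(packed, shapes, n_keep_axes=1):
    # Slice the packed axis back apart and restore each original shape.
    kept = packed.shape[:n_keep_axes]
    out, start = [], 0
    for shape in shapes:
        size = int(np.prod(shape))
        out.append(packed[..., start:start + size].reshape(*kept, *shape))
        start += size
    return out

a, b = np.zeros((2, 3, 4)), np.zeros((2, 5))
packed, shapes = pack_sketch([a, b])
assert packed.shape == (2, 17)
u, v = unpack_sketch(packed, shapes)
assert u.shape == (2, 3, 4) and v.shape == (2, 5)
```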

Cherry-picked from #1806
Closes #1835

@ricardoV94 force-pushed the tweaks-to-reshape-ops branch from 03342fc to e8b8fb6 on January 12, 2026 at 18:16
@ricardoV94 added the bug (Something isn't working) and Op implementation labels Jan 12, 2026
@ricardoV94 changed the title from "Tweaks to reshape ops" to "Tweaks to reshape Ops" Jan 12, 2026
@jessegrabowski (Member) left a comment

Approved with small suggestions.

I still prefer a list of axes as the public API for join_dims over start_axis, n_axes, but I'm not willing to have a fight about it.

@ricardoV94 (Member, Author) commented Jan 12, 2026

> Approved with small suggestions.
>
> I still prefer a list of axes as the public API for join_dims over start_axis, n_axes, but I'm not willing to have a fight about it.

The counterargument is that for the internal uses we do have, the list form was more awkward, and you can't really join non-consecutive axes anyway, so there's no benefit either.

So we were asking users to transform their range into a list of axes, just so we could go and undo that, because we want the range anyway.
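
In other words (toy illustration with made-up axis values):

```python
# What the list-of-axes API would amount to internally:
axes = [1, 2, 3]                     # the user spells out a contiguous range...
axis, n_axes = axes[0], len(axes)    # ...which we immediately collapse back
assert axes == list(range(axis, axis + n_axes))  # contiguity is required anyway
```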

@jessegrabowski (Member) commented Jan 12, 2026

Yeah, you're right; that's why I don't want to fight.

But I think that, from a user's perspective, passing the list is more natural. I'm basing this opinion on how confused I was by the pt.split API.

@ricardoV94 (Member, Author) commented

What about split? That seems like a completely different thing.

@jessegrabowski (Member) commented

It requires two arguments, splits_size and n_splits, when I would only have expected one. Similarly, I would expect only one argument for join_dims: the dims I want to join.

But I looked it up to type this message and I see that n_splits is actually optional. But it also doesn't have a docstring, so that probably didn't help me back then.

@ricardoV94 (Member, Author) commented

n_splits was needed before static shapes, because a tensor_variable([1, 2, 3]) wasn't known to have length 3 without introspecting it.

It could have been avoided by expecting a sequence of scalars but they didn't go with that.
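
For context, roughly how that looks today (keyword names from the current pt.split signature, to the best of my recollection):

```python
import numpy as np
import pytensor
import pytensor.tensor as pt

x = pt.vector("x")
sizes = pt.as_tensor([1, 2, 3])  # symbolic: its length isn't knowable a priori
parts = pt.split(x, splits_size=sizes, n_splits=3)  # hence the explicit n_splits
fn = pytensor.function([x], parts)
a, b, c = fn(np.arange(6.0))
assert list(a) == [0.0]
```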

@ricardoV94 (Member, Author) commented Jan 12, 2026

For join_dims I think a list of axes would be fine if we supported arbitrary axes.

I strongly believe that if we ask for a sequence of axes, users will expect they can be non-consecutive. This API immediately forces them to think about the valid cases: it's a range. We could rename it to join_dims_range?

I'm still partial to a more flexible join_dims where we do the transpose for the user, to align the dims to be joined. But note that wasn't needed for our purposes in pack.

Although then you need to decide where the joined axis ends up, as there's no longer a reasonable default. xarray puts it at the end with stack, but positions don't matter as much for them.
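
For reference, a sketch of what that more flexible helper could look like (hypothetical API; the names join_dims_flexible and output_axis are made up here, and this is exactly the transpose-then-join boilerplate it would absorb):

```python
import numpy as np

def join_dims_flexible(x, axes, output_axis=0):
    # Move the requested (possibly non-consecutive) axes next to each other
    # at output_axis, then join them into one.
    axes = [a % x.ndim for a in axes]
    rest = [a for a in range(x.ndim) if a not in axes]
    order = rest[:output_axis] + axes + rest[output_axis:]
    x = np.transpose(x, order)
    joined = int(np.prod(x.shape[output_axis:output_axis + len(axes)]))
    return x.reshape(
        x.shape[:output_axis] + (joined,) + x.shape[output_axis + len(axes):]
    )

x = np.zeros((2, 3, 4, 5))
assert join_dims_flexible(x, axes=[1, 3], output_axis=1).shape == (2, 15, 4)
```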

@ricardoV94 (Member, Author) commented Jan 12, 2026

Ah, I forgot to say why I did it, though. You can't express the expand_dims edge case with a list-of-axes argument: there's no equivalent of join_dims(x, start_axis=2, n_axes=0).

Unless you offer an output_axis argument that decides where the joined axis goes.
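
A tiny illustration of why the position matters (NumPy only, just to show shapes):

```python
import numpy as np

x = np.zeros((2, 3, 4))
# With (axis, n_axes), "join zero axes at position 2" has an unambiguous meaning:
assert np.expand_dims(x, 2).shape == (2, 3, 1, 4)  # i.e. join_dims(x, axis=2, n_axes=0)
# With a list of axes, axes=[] carries no position, so the edge case is inexpressible.
```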

@jessegrabowski (Member) commented

I think this is fine as-is. We can revisit it if we really want a more complex operation. I think re-arranging the dimensions is out of scope for this Op. We're slowly just building back up to dimshuffle.

@ricardoV94 (Member, Author) commented Jan 12, 2026

> I think re-arranging the dimensions is out of scope for this Op. We're slowly just building back up to dimshuffle.

It would be done by the helper, not the Op, so it would still combine DimShuffle and JoinDims, just doing the boilerplate work for the user. The API would be something like #1844.

Agreed on leaving it to another PR, but we have to decide soon, as it will be a breaking change (this one already is).

Commit: "Also: Allow default `None` on unpack"

Comment on lines 51 to 55 of the diff (a ruff-formatted list literal):

    2,
    1,
    3,
    4,
    5,
@jessegrabowski (Member): peak ruff

@ricardoV94 (Member, Author): just need to take the last comma

Commit: "Mainly, joining 0 axes is equivalent to inserting a new dimension. This is the mirror of how splitting a single axis into an empty shape is equivalent to squeezing it."
@ricardoV94 force-pushed the tweaks-to-reshape-ops branch from b7c3a87 to 802f251 on January 12, 2026 at 22:45
@ricardoV94 merged commit b1678fd into pymc-devs:main on Jan 12, 2026 (11 of 12 checks passed)
@ricardoV94 deleted the tweaks-to-reshape-ops branch on January 12, 2026 at 22:46
Labels: bug (Something isn't working), Op implementation

Linked issues: Fix unpack sharp edges

2 participants